import gym
import numpy as np
from gym import spaces
from gym.envs.registration import register


class ImprovedHopperEnv(gym.Env):
    """
    An improved hopper environment that doesn't require MuJoCo.

    This environment simulates a 2D hopper with 3 joints and enhanced physics.
    It follows the classic Gym API: ``reset()`` returns an observation and
    ``step()`` returns ``(observation, reward, done, info)``.

    Observation (11 values, float32):
        [x, z, theta1, theta2, theta3, x_vel, z_vel,
         theta1_vel, theta2_vel, theta3_vel, contact]
    Action (3 values): joint torques, clipped to [-1, 1] before use.

    Stochasticity is added through:
    - action_noise_scale: Std-dev of Gaussian noise added to the action before applying it
    - dynamics_noise_scale: Std-dev of Gaussian noise added to the angular and
      linear velocities on every step
    - obs_noise_scale: Std-dev of Gaussian noise added to the observations before returning them

    Random numbers come from NumPy's global generator unless ``seed()`` has
    been called, in which case a generator private to this instance is used.
    """

    # Source of randomness. Defaults to NumPy's global generator so existing
    # code relying on ``np.random.seed`` behaves as before; ``seed()`` replaces
    # it with a per-instance generator.
    _rng = np.random

    def __init__(self, action_noise_scale=0.03, dynamics_noise_scale=0.02, obs_noise_scale=0.01,
                 max_steps=1000):
        """
        Build the environment and reset it to an initial state.

        Args:
            action_noise_scale: Std-dev of the noise added to actions (0 disables it).
            dynamics_noise_scale: Std-dev of the noise added to velocities (0 disables it).
            obs_noise_scale: Std-dev of the noise added to observations (0 disables it).
            max_steps: Number of steps after which ``step()`` reports ``done``.
        """
        super().__init__()

        # 3-dimensional action space (3 joint torques)
        self.action_space = spaces.Box(
            low=-1.0, high=1.0, shape=(3,), dtype=np.float32
        )

        # 11-dimensional observation space (similar to the real Hopper)
        # [x, z, theta1, theta2, theta3, x_vel, z_vel, theta1_vel, theta2_vel, theta3_vel, contact]
        # NOTE: observations are not clipped to these bounds; e.g. x grows
        # without limit as the hopper moves forward.
        self.observation_space = spaces.Box(
            low=-10.0, high=10.0, shape=(11,), dtype=np.float32
        )

        # Noise parameters - matching the style from other environments
        self.action_noise_scale = action_noise_scale
        self.dynamics_noise_scale = dynamics_noise_scale
        self.obs_noise_scale = obs_noise_scale

        # Physics parameters
        self.gravity = -9.8
        self.mass = 1.0
        self.timestep = 0.05
        self.joint_damping = 0.1

        # Enhanced physics parameters
        self.joint_coupling = 0.3  # How much neighbouring torques add to the middle joint
        self.velocity_coupling = 0.2  # How joint velocities contribute to forward motion
        self.momentum_factor = 0.9  # Fraction of the previous velocity kept each step
        self.contact_force = 5.0  # Gain applied when pushing against the ground
        self.leg_length = 0.5  # Length of leg for determining contact

        # State variables
        self.position = np.zeros(2)  # x, z
        self.velocity = np.zeros(2)  # x_vel, z_vel
        self.angles = np.zeros(3)  # 3 joint angles
        self.angle_vels = np.zeros(3)  # 3 joint angular velocities
        self.contact = 0  # Contact with ground (0 or 1)

        # Episode tracking
        self.steps = 0
        self.max_steps = max_steps

        # Initialize state
        self.reset()

    def seed(self, seed=None):
        """
        Switch this instance to a private random generator.

        Args:
            seed: Seed passed to ``numpy.random.RandomState``; ``None`` lets
                NumPy pick one.

        Returns:
            A one-element list containing ``seed``.
        """
        self._rng = np.random.RandomState(seed)
        return [seed]

    def reset(self):
        """
        Reset the environment to an initial state.

        Returns:
            The first (possibly noisy) observation of the new episode.
        """
        # Start at x = 0, z = 0.5 with no linear motion
        self.position = np.array([0.0, 0.5])
        self.velocity = np.zeros(2)

        # Small random joint angles, joints initially at rest
        self.angles = self._rng.uniform(-0.1, 0.1, size=3)
        self.angle_vels = np.zeros(3)

        # Not in contact initially
        self.contact = 0

        # Reset step counter
        self.steps = 0

        return self._get_obs()

    def _get_obs(self):
        """
        Get current observation with optional noise.

        Returns:
            A float32 array of 11 values: position, joint angles, linear
            velocity, joint angular velocities and the contact flag.
        """
        # Combine all state variables (np.concatenate returns a new array, so
        # adding noise below never touches the underlying state)
        obs = np.concatenate([
            self.position,
            self.angles,
            self.velocity,
            self.angle_vels,
            [self.contact]
        ])

        # Add observation noise - consistent with other environments
        if self.obs_noise_scale > 0:
            obs = obs + self._rng.normal(0, self.obs_noise_scale, size=obs.shape)

        return obs.astype(np.float32)  # Ensure correct dtype

    def step(self, action):
        """
        Take a step with the given action.

        Args:
            action: Three joint torques. Any array-like holding exactly three
                values is accepted; values are clipped to [-1, 1] before they
                drive the physics.

        Returns:
            A tuple ``(observation, reward, done, info)`` where ``reward`` is a
            Python float, ``done`` a Python bool and ``info`` an empty dict.

        Raises:
            ValueError: If ``action`` does not contain exactly three values.
        """
        # Validate before touching any state so a bad call leaves the episode intact
        action = np.array(action, dtype=np.float32).reshape(-1)
        if action.shape != self.action_space.shape:
            raise ValueError(
                f"Expected an action with {self.action_space.shape[0]} elements, got {action.size}"
            )

        self.steps += 1

        # Apply action noise - consistent with other environments
        if self.action_noise_scale > 0:
            noisy_action = action + self._rng.normal(0, self.action_noise_scale, size=action.shape)
        else:
            noisy_action = action
        # Always keep the applied torque inside the action bounds, whether or
        # not noise is enabled (np.clip returns a new array).
        noisy_action = np.clip(noisy_action, -1.0, 1.0)

        # Apply joint torques (with improved physics)
        joint_forces = noisy_action * 3.0  # Scale to reasonable torque

        # Joint coupling - only interior joints (here: the middle one) receive
        # a share of both neighbours' torques; the end joints are left as is.
        coupled_forces = joint_forces.copy()
        for i in range(1, len(self.angles) - 1):
            coupled_forces[i] += self.joint_coupling * (joint_forces[i - 1] + joint_forces[i + 1])

        # Update angular velocities based on torques, with velocity-proportional damping
        self.angle_vels += coupled_forces - self.joint_damping * self.angle_vels

        # Apply dynamics noise to the state - consistent with other environments
        if self.dynamics_noise_scale > 0:
            # Apply noise to angular velocities
            self.angle_vels += self._rng.normal(0, self.dynamics_noise_scale, size=self.angle_vels.shape)

            # Apply noise to linear velocities
            self.velocity += self._rng.normal(0, self.dynamics_noise_scale, size=self.velocity.shape)

        # Update joint angles
        self.angles += self.angle_vels * self.timestep
        self.angles = np.clip(self.angles, -1.0, 1.0)  # Limit joint angles

        # Enhanced leg dynamics - convert joint movements to forward motion
        forward_force = np.sum(coupled_forces * np.cos(self.angles))
        forward_force += self.velocity_coupling * np.sum(np.abs(self.angle_vels))  # Joint speed contributes
        upward_force = np.sum(coupled_forces * np.sin(self.angles))

        # Check ground contact: the contact height shrinks as the joints bend
        if self.position[1] <= self.leg_length * np.cos(np.mean(np.abs(self.angles))):
            self.contact = 1
            # Add contact forces - pushing against ground creates motion
            if upward_force < 0:  # Pushing down
                forward_force += self.contact_force * np.abs(upward_force) * 0.5  # Convert some downward to forward
                upward_force = 0.0  # Ground reaction cancels the downward push

                # Add a small bouncing effect
                if self.velocity[1] < 0:
                    upward_force -= self.velocity[1] * self.mass  # Counteract downward velocity
        else:
            self.contact = 0

        # Update velocity: keep a fraction of the old velocity and add the new forces
        prev_velocity = self.velocity.copy()
        self.velocity[0] = self.momentum_factor * prev_velocity[0] + forward_force * self.timestep
        self.velocity[1] = self.momentum_factor * prev_velocity[1] + (upward_force + self.gravity) * self.timestep

        # Apply drag - more when in contact with ground
        air_drag = 0.95
        drag = 0.6 if self.contact else air_drag
        self.velocity *= drag

        # Update position
        self.position += self.velocity * self.timestep

        # Enforce ground constraint
        min_height = self.leg_length * 0.5  # Minimum height
        if self.position[1] < min_height:
            self.position[1] = min_height
            # Bounce when hitting ground
            if self.velocity[1] < 0:
                self.velocity[1] = -self.velocity[1] * 0.3  # 30% bounce

        # Calculate reward - enhanced version
        # 1. Forward velocity reward - scaled up
        forward_reward = 5.0 * self.velocity[0]

        # 2. Energy efficiency penalty (reduced), based on the requested action
        energy_penalty = 0.05 * np.sum(np.square(action))

        # 3. Staying alive bonus - increased
        alive_bonus = 2.0

        # 4. Height bonus - encourage getting off the ground
        height_bonus = self.position[1] * 2.0

        # 5. Falling penalty
        # NOTE(review): the height was just clamped to leg_length * 0.5 (0.25
        # with the default leg_length), so this penalty and the fall clause of
        # `done` below never trigger unless leg_length is lowered below 0.2.
        # Left unchanged because altering it would change the task; confirm intent.
        falling_penalty = 0 if self.position[1] > 0.1 else -5.0

        # 6. Movement bonus - encourage changing joint angles
        movement_bonus = 0.1 * np.sum(np.abs(self.angle_vels))

        reward = forward_reward - energy_penalty + alive_bonus + height_bonus + falling_penalty + movement_bonus

        # Check if done: fallen and (almost) at rest, or step budget exhausted
        fallen = self.position[1] < 0.1 and abs(self.velocity[1]) < 0.05
        done = bool(fallen or self.steps >= self.max_steps)

        # Return plain Python scalars rather than NumPy scalar types
        return self._get_obs(), float(reward), done, {}

    def render(self, mode='human'):
        """
        Simple text rendering of hopper state.

        Prints position, velocity and the contact flag when ``mode`` is
        ``'human'``; does nothing for any other mode.
        """
        if mode == 'human':
            print(f"Position: [{self.position[0]:.2f}, {self.position[1]:.2f}], "
                  f"Velocity: [{self.velocity[0]:.2f}, {self.velocity[1]:.2f}], "
                  f"Contact: {self.contact}")


# Expose the environment to Gym's registry under the id "ImprovedHopper-v0".
# The entry point is given as a "module:ClassName" string, and the registered
# episode length is capped at 1000 steps.
register(id="ImprovedHopper-v0", entry_point="improved_hopper:ImprovedHopperEnv", max_episode_steps=1000)